{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from initial_dataSet import DataSet\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataSet: KDDCUP\n"
     ]
    }
   ],
   "source": [
    "basedir = '../'\n",
    "baseline_list=('DINA','NCD')\n",
    "baseline=baseline_list[0]\n",
    "\n",
    "dataSet_list = ('ASSIST_0910', 'ASSIST_2017', 'JUNYI', 'MathEC', 'KDDCUP')\n",
    "save_list = ('a0910/', 'a2017/', 'junyi/', 'math_ec/', 'kddcup/')\n",
    "\n",
    "dataSet_idx=4\n",
    "data_set_name = dataSet_list[dataSet_idx]\n",
    "dataSet = DataSet(basedir, data_set_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_item_concept_df(item:pd.DataFrame) -> pd.DataFrame:\n",
    "    item = item[~item.index.duplicated()]  # 去除重复项\n",
    "    item_list, concept_list = [], []\n",
    "    for idx in item.index:\n",
    "        now_concept_list = eval(item.loc[idx, 'knowledge_code'])\n",
    "        item_list.extend([idx] * len(now_concept_list))\n",
    "        concept_list.extend(now_concept_list)\n",
    "    return pd.DataFrame({'item': item_list, 'concept': concept_list}).astype('int')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "read_dir='./'+baseline+'/output/'+save_list[dataSet_idx]\n",
    "\n",
    "cogn_state=np.loadtxt(read_dir+'cognitive_state.csv',delimiter=',')\n",
    "record=dataSet.record.reset_index()\n",
    "item_conc=get_item_concept_df(dataSet.item)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 16/16 [00:21<00:00,  1.35s/it]\n"
     ]
    }
   ],
   "source": [
    "doa_list=[]\n",
    "for k in tqdm(range(1,cogn_state.shape[1]+1)):\n",
    "    cong_doa_count=0\n",
    "    item_cong_doa_count=0\n",
    "    item_k=item_conc[item_conc['concept']==k-1]['item'].unique()\n",
    "    record_k=record.set_index('item_id').loc[item_k,:]\n",
    "    if len(item_k)>0 and len(record_k)>0:\n",
    "        for j in item_k:\n",
    "            record_j=record_k.loc[j,:]\n",
    "            stu_j=np.array(record_j['user_id']).reshape(-1).astype('int')\n",
    "            cong_j=cogn_state[stu_j-1,k-1]\n",
    "            sort_stu_idx=cong_j.argsort() # 从小到大\n",
    "            sort_score_j=np.array(record_j['score']).reshape(-1)[sort_stu_idx]\n",
    "            sort_cong_j=cong_j[sort_stu_idx]\n",
    "            for i in range(len(sort_stu_idx)):\n",
    "                fliter=((sort_cong_j[i:]-sort_cong_j[i])>=0)\n",
    "                cong_doa_count+=fliter.sum()\n",
    "                score_i=sort_score_j[i:][fliter]\n",
    "                if len(score_i)>0:\n",
    "                    item_cong_doa_count+=((score_i-sort_score_j[i])>=0).sum()\n",
    "        if cong_doa_count>0:\n",
    "            doa_list.append(item_cong_doa_count/cong_doa_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DINA\n",
      "KDDCUP 的DOA = 0.9052007940936431\n"
     ]
    }
   ],
   "source": [
    "print(baseline)\n",
    "print('{} 的DOA = {}'.format(data_set_name,np.mean(doa_list)))"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "527a93331b4b1a8345148922acc34427fb7591433d63b66d32040b6fbbc6d593"
  },
  "kernelspec": {
   "display_name": "Python 3.8.12 ('pytorch')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
